前面文章在說明 Sequence 建構實例時,將layer物件陣列以建構子參數傳入後,Sequence 會自動將 layer物件陣列裡之物件逐一加入到keras.engine.sequential.Sequential.layers 與 keras.engine.sequential.Sequential._self_tracked_trackables 二個list集合中。
如果額外使用 keras.engine.sequential.Sequential.Add 加入,可以在什麼時機點加入才有效?
因為我們由之前model fit 文章得知,如果 keras.engine.sequential.Sequential 實例之模型 未執行 model build,會於model fit 做初始化的 model build 動作,也就是執行 keras.engine.sequential._build_graph_network_for_inferred_shape函式 與 keras.engine.functional._init_graph_network 函式 來為模型確保對 layer層逐一建構過。 所以確定的是如果沒做過 model build,至少fit會偵測到而補做。
但如果於 model fit 之前,就執行 model build 函式,keras.engine.sequential.Sequential.Add 什麼時候執行才算有效?
其實只要在 model fit 之前任何時機點(當然要完成Sequential實體之建立)都可以有效加入。當然model build之前加入一定沒問題。model build之後也沒問題嗎? 是的。
原因是keras.engine.sequential.Sequential.Add 在最後階段,如果發現模型已經有build過,keras.engine.sequential.Sequential._graph_initialized 因為有執行過keras.engine.functional._init_graph_network而被設為True,代表已經將之前的每個layer共同構成的topology之連結找出,只要透過更新最後的output layer為新加入的layer 物件,再執行一次keras.engine.functional._init_graph_network(inputs, outputs),即可再建構出新的topology連結,並重新設定keras.engine.sequential._self_tracked_trackables 與 keras.engine.training.Model.layers。
範例如下:
from tensorflow import keras
from tensorflow.keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
model = Sequential([
layers.Dense(512, activation="relu")
])
model.compile(optimizer="rmsprop",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
model.build(input_shape=(None, 784))
model.add(layers.Dense(10, activation="softmax"))
model.add(layers.Dense(20, activation="relu"))
print(model.layers[0])
print(model.layers[1])
print(model.layers[2])
print(model._self_tracked_trackables[0])
print(model._self_tracked_trackables[1])
print(model._self_tracked_trackables[2])
print(model._self_tracked_trackables[3])
model.summary()